# The OpenAI API key is read from the environment instead of being pasted into
# this file, so the secret never ends up in version control.
# Provide it either with an `OPENAI_API_KEY=...` line in ~/.Renviron or by
# running Sys.setenv(OPENAI_API_KEY = "<your key>") in the console beforehand.
# Only a warning is raised when it is missing, because every pipeline step
# below can also be served from its cached CSV without calling the API.
if (!nzchar(Sys.getenv("OPENAI_API_KEY"))) {
  warning(
    "OPENAI_API_KEY is not set; GPT calls will fail until it is provided ",
    "(e.g. via ~/.Renviron).",
    call. = FALSE
  )
}

# Helper Functions ----------------------------------------------------------------------------------------------------------------

# --- 1. Safe JSON parsing ---
# Non-throwing JSON parser: calling it yields a list with `result` (the parsed
# value, kept as nested lists because simplifyVector = FALSE) and `error`.
safe_fromJSON <- safely(
  .f = function(x) fromJSON(x, simplifyVector = FALSE),
  otherwise = NULL,
  quiet = TRUE
)

# --- 2. Serialise one (group_id, pair_id) discussion to a JSON string ---
#     `data` holds the rows of a single discussion, one utterance per row.
#     Returns a plain character string (not a "json" object).
create_discussion_json <- function(data) {
  # Keep only the first run of digits of the speaker label
  labelled <- mutate(
    data,
    who_speaking_label = str_extract(who_speaking_label, "\\d+")
  )

  # Turn one row (passed as named arguments by pmap) into an utterance record
  row_to_utterance <- function(...) {
    fields <- list(...)

    # A missing stance code is emitted as JSON null rather than NA
    stance <- NULL
    if (!is.na(fields$pro_a_b)) {
      stance <- fields$pro_a_b
    }

    list(
      transcript_line_id = fields$transcript_line_id,
      who_speaking_label = fields$who_speaking_label,
      speech_english     = fields$speech_english,
      pro_a_b            = stance
    )
  }

  utterances <- pmap(labelled, row_to_utterance)

  # The utterances are wrapped in one extra outer list before serialising
  json_obj <- toJSON(
    list(utterances),
    auto_unbox = TRUE,
    pretty     = TRUE,
    null       = "null"
  )
  as.character(json_obj)
}

# --- 3. Prompt generator: compare discussion 1 vs. discussion 2 ---
# Builds the hypothesis-generation prompt by interpolating the two transcript
# JSON strings into a fixed template with str_glue().
#
# Arguments:
#   discussion_json_1, discussion_json_2: JSON text of the two discussions
#     (inserted at {discussion_json_1} / {discussion_json_2}).
# Returns: the value of str_glue(), i.e. the filled-in prompt.
#
# The template text (including its indentation) is sent to the model verbatim,
# so edits to it change model behaviour.
# NOTE(review): this background states one worker was male and one transgender,
# while generate_rating_prompt() says the second worker was "male, female, or
# transgender" -- confirm which description applies to the discussions used here.
gen_hypotheses_prompt <- function(discussion_json_1, discussion_json_2) {
  str_glue("
    BACKGROUND:
    Below are transcripts from 2 group discussions in Chennai, India, in which 3 participants discussed who they would prefer to hire to deliver groceries to their home. Participants had to choose between option A or B.
    They were shown photos of the two delivery workers, one of whom was male, and one of whom was transgender.
    Each discussion discussed multiple sets of options, changing the choice of Option A and Option B each time.
    The grocery items on offer were: Aachi masala spice, tea powder, and ghee. Option A and B may have offered the same set of items, or different sets of items.
    They were also in some cases given information about:
     - the languages spoken by the delivery workers (Tamil only, Tamil and English)
     - the delivery workers' experience
     - how many deliveries they completed in a training task
    Participants were asked to discuss which option they prefer, and why, and then make a collective choice between the two options.

    TRANSCRIPTS:
       The transcripts are in JSON format.
       Each discussion is represented as a list of utterances, with each utterance containing the following fields:
       - transcript_line_id: the order of the utterance in the discussion
         - who_speaking_label: the label of the participant speaking (1, 2, or 3, or unknown)
         - speech_english: the English translation of the participant's speech
         - pro_a_b: a column manually coded by research assistants indicate whether the participant is arguing to choose A or B, or neither

    TRANSCRIPT 1:
      {discussion_json_1}

    TRANSCRIPT 2:
      {discussion_json_2}

      TASK:
        Your task is to identify what has changed from TRANSCRIPT 1 to TRANSCRIPT 2. Focus on the generalizable insight that can be applied in other contexts. Ignore things that are specific to these transcripts. Do not make references to these transcripts that may not be relevant for others.
        Come up with an insight that captures the sort of change observed moving from Group 1 to Group 2.

      Come up with an insight as a single sentence in this exact format:
      Hypothesis: _____ is the main difference between Group 2 compared to Group 1.

        Please make sure that the hypothesis is:
        i. clear (i.e., precise, not too wordy, and easy to understand);
        ii. generalizable to novel situations (i.e., they would make sense if applied to other transcripts);
        iii. empirically plausible (i.e., this is a dimension on which messages can vary on);
        iv. unidimensional (i.e., avoid hypotheses that list multiple constructs so if there are many things changing, pick one);
        v. usable (i.e., a human equipped with this insight could evaluate another group discussion in a similar way)
    "
  )
}

# --- 4. Random sampling temperature ---
# Draws a single value uniformly between min_temp and max_temp
# (defaults: 0.1 and 0.9). Uses the session RNG.
random_temperature <- function(min_temp = 0.1, max_temp = 0.9) {
  drawn <- stats::runif(n = 1, min = min_temp, max = max_temp)
  drawn
}

# --- 5. GPT call to generate a hypothesis ---
# Sends `prompt` as a single user message at a randomly drawn temperature.
#
# Arguments:
#   prompt:             prompt text (see gen_hypotheses_prompt()).
#   model:              model name passed to create_chat_completion().
#   verbose:            if TRUE, print the temperature that was drawn.
#   min_temp, max_temp: bounds for the random temperature.
# Returns a list with:
#   hypothesis_text: the model reply, or NULL if the call failed;
#   temperature:     the temperature used;
#   error:           the error message if the call failed, otherwise NULL.
# A failed call is reported with warning() and never raises an error.
generate_hypothesis <- function(prompt,
                                model = "gpt-4o-mini",
                                verbose = TRUE,
                                min_temp = 0.1,
                                max_temp = 0.9) {
  temp_used <- random_temperature(min_temp, max_temp)
  if (verbose) {
    cat("[generate_hypothesis] Using temperature =", temp_used, "\n")
  }

  # The handler's value is returned from tryCatch(); assigning to an outer
  # variable from inside the handler would only modify the handler's own scope
  # and the error message would be lost.
  outcome <- tryCatch({
    completion <- create_chat_completion(
      model       = model,
      messages    = list(list("role" = "user", "content" = prompt)),
      temperature = temp_used
    )
    list(content = completion$choices$message.content, error = NULL)
  }, error = function(e) {
    warning("Error in GPT API call for hypothesis generation: ", conditionMessage(e))
    list(content = NULL, error = conditionMessage(e))
  })

  list(hypothesis_text = outcome$content,
       temperature     = temp_used,
       error           = outcome$error)
}

# --- 5B. GPT call to "clean" the hypothesis, removing references to Group 1 and Group 2 ---
# Asks the model to rephrase a raw hypothesis so it no longer mentions the
# groups, then strips "Hypothesis:", literal backslash-n sequences and double
# quotes from the reply.
#
# Arguments:
#   hypothesis_text: raw hypothesis (as returned by generate_hypothesis()).
#   model:           model name passed to create_chat_completion().
#   verbose:         currently unused; kept for a uniform signature.
# Returns a list with:
#   hypothesis_cleaned: cleaned text, or "" when nothing could be produced;
#   error:              error message, or NA_character_ on success.
# A failed call is reported with warning() and never raises an error.
clean_hypothesis <- function(hypothesis_text,
                             model = "gpt-4o-mini",
                             verbose = FALSE) {

  # Nothing to clean (e.g. the generation call failed): skip the API call
  # instead of sending an empty prompt.
  has_text <- length(hypothesis_text) > 0L &&
    !is.na(hypothesis_text[[1]]) &&
    nzchar(hypothesis_text[[1]])
  if (!has_text) {
    return(list(hypothesis_cleaned = "",
                error              = "No hypothesis text supplied; cleaning call skipped."))
  }

  clean_prompt <- glue::glue("
    You are an expert editor.
    Task: Remove references to 'Group 1' and 'Group 2' from the following hypothesis and rephrase it so that it stands on its own without mentioning groups.
    Provide only the cleaned version as a string, nothing else.

    EXAMPLES:
    'Hypothesis: The preference for delivery workers in Group 2 is influenced by the perceived convenience and experience associated with female delivery workers compared to other gender identities in Group 1.' --> 'The perception of convenience and experience associated with female delivery workers compared to other gender identities'
    'Hypothesis: Increased support for transgender workers in hiring decisions is the main difference between Group 2 compared to Group 1.' --> 'Support for transgender workers in hiring decisions.'

    Original Hypothesis:
    \"{hypothesis_text}\"
  ")

  # The handler's value is returned from tryCatch(); assigning to an outer
  # variable from inside the handler would not reach this function's scope.
  outcome <- tryCatch({
    completion <- create_chat_completion(
      model       = model,
      messages    = list(list("role" = "user", "content" = clean_prompt)),
      temperature = 0.2
    )
    cleaned <- completion$choices$message.content %>%
      str_replace_all("Hypothesis\\:|\\\\n|\"", "")
    list(cleaned = cleaned, error = NULL)
  }, error = function(e) {
    warning("[clean_hypothesis] Error in GPT API call: ", conditionMessage(e))
    list(cleaned = NULL, error = conditionMessage(e))
  })

  list(hypothesis_cleaned = outcome$cleaned %||% "",
       error              = outcome$error %||% NA_character_)
}



# --- 6. Generate the rating prompt (for a single discussion vs. a single hypothesis) ---
# Fills a fixed template with one transcript and one hypothesis via str_glue().
#
# Arguments:
#   discussion_json: JSON text of the discussion (inserted at {discussion_json}).
#   hypothesis:      hypothesis text (inserted at {hypothesis}).
# Returns: the value of str_glue(), i.e. the filled-in prompt.
#
# The template asks for a JSON reply with the keys `explanation` and `rating`
# (integer 1-10, or NA when the transcript is not informative enough).
# The template text is sent to the model verbatim, so edits change behaviour.
generate_rating_prompt <- function(discussion_json, hypothesis) {
  str_glue("
    BACKGROUND:
    Below is a transcript from a group discussion in Chennai, India, in which 3 participants discussed who they would prefer to hire to deliver groceries to their home. Participants had to choose between option A or B.
    They were shown photos of the two delivery workers, one of whom was male, and one of whom was either male, female, or transgender.
    Each discussion discussed multiple sets of options, changing the choice of Option A and Option B each time.
    The grocery items on offer were: Aachi masala spice, tea powder, and ghee. Option A and B may have offered the same set of items, or different sets of items.
    They were also in some cases given information about:
     - the languages spoken by the delivery workers (Tamil only, Tamil and English)
     - the delivery workers' experience
     - how many deliveries they completed in a training task
    Participants were asked to discuss which option they prefer, and why, and then make a collective choice between the two options.

    TRANSCRIPT:
      {discussion_json}

    HYPOTHESIS:
      {hypothesis}

    TASK:
      Rate how much this single transcript aligns with the description of the hypothesis on a scale from 1 to 10, where:
        - 10 = maximally in line with the hypothesis
        - 1  = not at all in line with the hypothesis

      Output in JSON with two keys:
        - explanation (string)
        - rating (integer, 1-10 or return NA for the rating if the transcript does not contain enough information to make a rating.)
  ")
}

# --- 7. Parse the GPT rating response (JSON) ---
# Extracts `explanation` and `rating` from the model's reply.
#
# Arguments:
#   response: raw reply text, possibly wrapped in ``` / ```json fences.
# Returns a list with:
#   explanation: the explanation, or NA_character_ if absent/unparseable;
#   rating:      the rating as a number, or NA_real_ if absent, "NA",
#                non-numeric or unparseable.
# Never raises: any parsing problem yields NA fields.
parse_rating_response <- function(response) {
  empty <- list(explanation = NA_character_, rating = NA_real_)

  # Drop markdown code fences around the JSON
  clean_response <- gsub("```(json)?", "", response)
  # The prompt allows the model to answer `"rating": NA`; a bare NA is not
  # valid JSON, so turn it into null before parsing.
  clean_response <- gsub("(\"rating\"\\s*:\\s*)NA\\b", "\\1null",
                         clean_response, perl = TRUE)

  parsed <- safe_fromJSON(clean_response)
  result <- parsed$result
  if (is.null(result) || !is.list(result)) {
    return(empty)
  }

  tryCatch({
    explanation_val <- if (is.null(result$explanation)) NA_character_ else result$explanation

    # A missing / null rating must not discard the explanation, so the rating
    # is only converted when exactly one value is present.
    rating_chr <- as.character(unlist(result$rating))
    rating_num <- NA_real_
    if (length(rating_chr) == 1L && !is.na(rating_chr) && rating_chr != "NA") {
      rating_num <- suppressWarnings(as.numeric(rating_chr))
    }

    list(explanation = explanation_val, rating = rating_num)
  }, error = function(e) {
    empty
  })
}

# --- 8. GPT call to generate a rating for one discussion/hypothesis ---
# Sends `prompt` as a single user message at a randomly drawn temperature and
# parses the reply with parse_rating_response().
#
# Arguments:
#   prompt:             prompt text (see generate_rating_prompt()).
#   model:              model name passed to create_chat_completion().
#   verbose:            currently unused; kept for a uniform signature.
#   min_temp, max_temp: bounds for the random temperature.
# Returns a list with:
#   explanation, rating: parsed fields (NA when the call or parsing failed);
#   temperature:         the temperature used;
#   error:               the error message if the call failed, otherwise NULL.
# A failed call is reported with warning() and never raises an error.
generate_rating <- function(prompt,
                            model = "gpt-4o-mini",
                            verbose = FALSE,
                            min_temp = 0.1,
                            max_temp = 0.9) {
  temp_used <- random_temperature(min_temp, max_temp)

  # The handler's value is returned from tryCatch(); assigning to an outer
  # variable from inside the handler would only modify the handler's own scope
  # and the error message would be lost.
  outcome <- tryCatch({
    completion <- create_chat_completion(
      model       = model,
      messages    = list(list("role" = "user", "content" = prompt)),
      temperature = temp_used
    )
    list(content = completion$choices$message.content, error = NULL)
  }, error = function(e) {
    warning("Error in GPT API call for rating: ", conditionMessage(e))
    list(content = NULL, error = conditionMessage(e))
  })

  # A failed call is parsed as an empty reply, which yields NA fields
  reply       <- if (is.null(outcome$content)) "" else outcome$content
  rating_info <- parse_rating_response(reply)

  list(
    explanation  = rating_info$explanation,
    rating       = rating_info$rating,
    temperature  = temp_used,
    error        = outcome$error
  )
}

# Dot-and-whisker plot of the correlation results.
#
# Arguments:
#   corr_data:  data frame with the columns hypothesis_id, hypothesis,
#               rating_type, correlation, conf_low and conf_high (as produced
#               by merge_with_outcome_and_correlate()).
#   wrap_width: character width at which hypothesis labels are wrapped.
#   base_size:  base font size passed to theme_minimal().
# Returns: a ggplot object with one row per hypothesis, a point per rating type
#   and horizontal bars from conf_low to conf_high.
#
# Functions are called with explicit namespaces so that the function does not
# attach packages to the caller's search path as a side effect.
plot_correlation_results <- function(corr_data,
                                     wrap_width = 40,
                                     base_size = 13) {
  # Step 1: build the wrapped "<id>: <text>" label and a fixed colour per
  # rating type (anything unrecognised is drawn in black)
  plot_data <- dplyr::mutate(
    corr_data,
    hypothesis = stringr::str_replace(hypothesis, "Hypothesis\\:", ""),
    hypothesis_wrapped = stringr::str_wrap(
      paste0(hypothesis_id, ": ", hypothesis),
      width = wrap_width
    ),
    color = dplyr::case_when(
      rating_type == "rating_trans"     ~ "indianred",
      rating_type == "rating_non_trans" ~ "skyblue",
      rating_type == "rating_overall"   ~ "gray",
      TRUE                              ~ "black"
    )
  )

  # Step 2: fix the y-axis order to hypothesis_id order
  plot_data <- dplyr::arrange(plot_data, hypothesis_id, rating_type)
  plot_data <- dplyr::mutate(
    plot_data,
    hypothesis_wrapped = factor(hypothesis_wrapped, levels = unique(hypothesis_wrapped))
  )

  # Step 3: create the plot
  gg <- ggplot2::ggplot(plot_data, ggplot2::aes(
    x = correlation,
    y = hypothesis_wrapped,
    color = color
  )) +
    ggplot2::geom_point(size = 3, position = ggplot2::position_dodge(0.2)) +  # Points for correlations
    ggplot2::geom_errorbarh(ggplot2::aes(xmin = conf_low, xmax = conf_high),
                            position = ggplot2::position_dodge(0.2),
                            height = 0.2, size = 0.8) +  # Horizontal error bars
    ggplot2::scale_color_identity() +  # Use the assigned colors directly
    ggplot2::geom_vline(xintercept = 0, color = "gray", linetype = "dashed") +  # Zero line
    ggplot2::theme_minimal(base_size = base_size) +
    ggplot2::labs(
      x = "Correlation with r2_choose_trans",
      y = "Hypothesis",
      title = "Correlation Results with Confidence Intervals",
      color = "Rating Type"
    ) +
    ggplot2::theme(
      axis.text.y = ggplot2::element_text(size = 11),  # Adjust y-axis text size
      legend.position = "none"                         # Hide the legend
    )

  gg
}


# MASTER FUNCTIONS FOR HYPOTHESIS GENERATION & RATING (with partial saving) ---------

# --- 1) Prepare the discussion-level data ---
# Converts every (group_id, pair_id) discussion into one JSON string so each
# discussion can be fed to GPT individually.
#
# Arguments:
#   full_data:  utterance-level data containing group_id, pair_id and the
#               columns read by create_discussion_json().
#   output_csv: cache file; if it already exists it is read and returned
#               as-is, and `full_data` is not used.
# Returns: a tibble with the columns group_id, pair_id and discussion_json.
# Side effect: writes the result to `output_csv` (creating its directory).
prepare_discussion_data <- function(full_data,
                                    output_csv = "data/ml_hypothesis_generation/discussion_level_data.csv") {
  # If the file already exists, read it and skip re-building
  if (file.exists(output_csv)) {
    cat("[prepare_discussion_data] Loading existing discussion-level data from disk...\n")
    return(read_csv(output_csv, show_col_types = FALSE))
  }

  cat("[prepare_discussion_data] Creating discussion-level data...\n")

  # All discussions are kept (no filter on whether a trans worker is included).
  # transcript_line_id is the row order within each discussion.
  discussion_chunks <- full_data %>%
    group_by(group_id, pair_id) %>%
    mutate(transcript_line_id = row_number()) %>%
    group_split()

  # One output row per discussion
  discussion_level_data <- discussion_chunks %>%
    map(function(chunk) {
      tibble(
        group_id        = chunk$group_id[1],
        pair_id         = chunk$pair_id[1],
        discussion_json = create_discussion_json(chunk)
      )
    }) %>%
    bind_rows()

  # Save to disk
  dir.create(dirname(output_csv), showWarnings = FALSE, recursive = TRUE)
  write_csv(discussion_level_data, output_csv)
  discussion_level_data
}

# --- 2) Generate multiple hypotheses from pairs of distinct discussions ---
# Randomly picks `n_pairs` pairs of discussions and asks GPT for one hypothesis
# per pair, followed by a cleaning call that removes the group references.
#
# Arguments:
#   discussion_tbl:     tibble with group_id, pair_id and discussion_json.
#   n_pairs:            number of discussion pairs (= hypotheses) to generate.
#   model, verbose,
#   min_temp, max_temp: passed on to generate_hypothesis() / clean_hypothesis().
#   output_csv:         cache file; if it already exists it is read and
#                       returned as-is and nothing is generated.
# Returns: a tibble with one row per hypothesis (ids of both discussions, raw
#   and cleaned hypothesis, temperature, error and the prompt used).
# Side effects: writes `output_csv` once all pairs are done; resets the
#   session RNG through set.seed(123).
generate_discussion_hypotheses <- function(discussion_tbl,
                                           n_pairs              = 5,
                                           model                = "gpt-4o-mini",
                                           verbose              = FALSE,
                                           min_temp             = 0.1,
                                           max_temp             = 0.9,
                                           output_csv           = "data/ml_hypothesis_generation/hypotheses_tbl.csv") {

  # An existing file short-circuits generation entirely
  if (file.exists(output_csv)) {
    cat("[generate_discussion_hypotheses] Loading existing hypotheses from disk...\n")
    out_tbl <- readr::read_csv(output_csv, show_col_types = FALSE)
    return(out_tbl)
  }

  cat("[generate_discussion_hypotheses] Generating new hypotheses...\n")

  total_discussions <- nrow(discussion_tbl)
  if (total_discussions < 2) {
    stop("Not enough discussions to form distinct pairs.")
  }

  # -------------------------------------------------------------------------
  # Step A: Sample n_pairs pairs of row indices. Within a pair the two indices
  #         are distinct (a discussion is never compared with itself); across
  #         pairs the same discussion may appear again.
  # -------------------------------------------------------------------------
  # NOTE(review): this fixed seed overrides any seed set by the caller and also
  # fixes the temperatures drawn below -- kept so that the sampled pairs stay
  # the same as in earlier runs.
  set.seed(123)
  pairs_list <- replicate(
    n_pairs,
    sample(seq_len(total_discussions), 2, replace = FALSE),
    simplify = FALSE
  )

  results_list <- vector("list", length = n_pairs)

  for (i in seq_len(n_pairs)) {
    pair_idxs <- pairs_list[[i]]  # a vector of length 2, e.g. c(5, 11)
    row_1     <- discussion_tbl[pair_idxs[1], ]
    row_2     <- discussion_tbl[pair_idxs[2], ]

    # Create the prompt
    prompt <- gen_hypotheses_prompt(row_1$discussion_json, row_2$discussion_json)

    # Generate raw hypothesis
    hyp_call <- generate_hypothesis(prompt, model = model, verbose = verbose,
                                    min_temp = min_temp, max_temp = max_temp)

    raw_hypothesis <- hyp_call$hypothesis_text
    has_hypothesis <- length(raw_hypothesis) > 0L && !is.na(raw_hypothesis[[1]])

    # Only spend a cleaning call when there is something to clean
    cleaned <- ""
    if (has_hypothesis) {
      clean_call <- clean_hypothesis(
        hypothesis_text = raw_hypothesis,
        model           = model,
        verbose         = verbose
      )
      cleaned <- clean_call$hypothesis_cleaned
      if (length(cleaned) == 0L) {
        cleaned <- ""
      }
    }

    # Store results. A failed generation is recorded as NA: a NULL column
    # would be dropped from the tibble.
    results_list[[i]] <- tibble::tibble(
      hypothesis_id      = i,
      group_id_1         = row_1$group_id,
      pair_id_1          = row_1$pair_id,
      group_id_2         = row_2$group_id,
      pair_id_2          = row_2$pair_id,
      hypothesis         = if (has_hypothesis) raw_hypothesis else NA_character_,
      hypothesis_cleaned = cleaned,
      temperature        = hyp_call$temperature,
      error              = hyp_call$error %||% NA_character_,
      prompt             = prompt
    )
  }

  # Combine all results into a single data frame
  out_tbl <- dplyr::bind_rows(results_list)

  # Save to disk (creating the target directory if needed)
  dir.create(dirname(output_csv), showWarnings = FALSE, recursive = TRUE)
  readr::write_csv(out_tbl, output_csv)
  out_tbl
}


# --- 3) Rate each discussion against each hypothesis ---
# Makes `repeats_per_discussion` GPT rating calls for every
# (hypothesis, discussion) combination and persists each result immediately,
# so an interrupted run can be resumed by calling the function again.
#
# Arguments:
#   discussion_tbl:         tibble with group_id, pair_id, discussion_json.
#   hypotheses_tbl:         tibble with hypothesis_id, hypothesis_cleaned.
#   repeats_per_discussion: number of ratings per combination.
#   model, verbose,
#   min_temp, max_temp:     passed on to generate_rating().
#   output_csv:             results file; rows already present with a
#                           non-empty explanation are not rated again.
#   skip_from:              first row index (in the crossed table) that is
#                           considered; earlier rows are skipped
#                           unconditionally. Defaults to 1, i.e. nothing is
#                           skipped beyond what is already in `output_csv`.
#                           (This used to be hard-coded to 514000 inside the
#                           function body.)
# Returns: the contents of `output_csv` after the run, or an empty results
#   tibble if no file exists and nothing was rated.
rate_discussions_for_hypotheses <- function(discussion_tbl,
                                            hypotheses_tbl,
                                            repeats_per_discussion = 1,
                                            model                  = "gpt-4o-mini",
                                            verbose                = FALSE,
                                            min_temp               = 0.1,
                                            max_temp               = 0.9,
                                            output_csv             = "data/ml_hypothesis_generation/rating_results_long.csv",
                                            skip_from              = 1) {
  result_cols <- c("hypothesis_id", "group_id", "pair_id", "iteration",
                   "rating_val", "rating_expl", "temperature", "error")

  # Build one string key per (hypothesis, group, pair, iteration). Numbers are
  # routed through double first so that 1L and 1 produce the same key.
  to_key_part <- function(x) {
    if (is.numeric(x)) as.character(as.numeric(x)) else as.character(x)
  }
  make_key <- function(h, g, p, it) {
    paste(to_key_part(h), to_key_part(g), to_key_part(p), to_key_part(it),
          sep = "\u001f")
  }

  # If output_csv exists, load it for partial or complete results
  if (file.exists(output_csv)) {
    cat("[rate_discussions_for_hypotheses] Found existing rating_results_long file. Loading...\n")
    rating_results_existing <- read_csv(output_csv, show_col_types = FALSE)
  } else {
    rating_results_existing <- tibble(
      hypothesis_id = integer(),
      group_id      = character(),
      pair_id       = integer(),
      iteration     = integer(),
      rating_val    = double(),
      rating_expl   = character(),
      temperature   = double(),
      error         = character()
    )
  }

  # New rows can be appended line by line only when the file's header has
  # exactly the expected layout; otherwise fall back to rewriting the file.
  can_append  <- identical(names(rating_results_existing), result_cols)
  new_results <- list()

  # A combination counts as done when it has a non-empty explanation on file
  expl        <- rating_results_existing$rating_expl
  is_complete <- !is.na(expl) & expl != "" & expl != "NA"
  done_keys   <- make_key(
    rating_results_existing$hypothesis_id[is_complete],
    rating_results_existing$group_id[is_complete],
    rating_results_existing$pair_id[is_complete],
    rating_results_existing$iteration[is_complete]
  )

  # Every (hypothesis, discussion, iteration) combination
  all_rows <- crossing(
    hypotheses_tbl %>% select(hypothesis_id, hypothesis_cleaned),
    discussion_tbl %>% select(group_id, pair_id),
    iteration_within_disc = seq_len(repeats_per_discussion)
  )

  total_calls <- nrow(all_rows)
  cat("[rate_discussions_for_hypotheses] Need to process", total_calls, "discussion-hypothesis combos.\n")

  # One vectorised membership test replaces a full filter per row
  is_done <- make_key(all_rows$hypothesis_id, all_rows$group_id,
                      all_rows$pair_id, all_rows$iteration_within_disc) %in% done_keys

  dir.create(dirname(output_csv), showWarnings = FALSE, recursive = TRUE)

  for (idx in seq_len(total_calls)) {
    if (idx < skip_from) {
      next
    }
    row_info       <- all_rows[idx, ]
    h_id           <- row_info$hypothesis_id
    h_text         <- row_info$hypothesis_cleaned
    g_id           <- row_info$group_id
    p_id           <- row_info$pair_id
    iteration_this <- row_info$iteration_within_disc

    if (is_done[idx]) {
      if (verbose && idx %% 1000 == 0) {
        cat(sprintf("[rate_discussions_for_hypotheses] SKIP Row %d / %d, Processing h=%d, group_id=%s, pair_id=%d, iteration=%d \n",
                    idx, total_calls, h_id, g_id, p_id, iteration_this))
      }
      next
    }

    if (verbose) {
      cat(sprintf("[rate_discussions_for_hypotheses] Row %d / %d, Processing h=%d, group_id=%s, pair_id=%d, iteration=%d \n",
                  idx, total_calls, h_id, g_id, p_id, iteration_this))
    }

    # Prepare rating prompt
    disc_json <- discussion_tbl %>%
      filter(group_id == g_id, pair_id == p_id) %>%
      pull(discussion_json)

    if (length(disc_json) < 1) {
      warning("No discussion JSON found for group_id=", g_id, ", pair_id=", p_id)
      next
    }
    rating_prompt <- generate_rating_prompt(disc_json, h_text)
    rating_call   <- generate_rating(
      prompt   = rating_prompt,
      model    = model,
      verbose  = verbose,
      min_temp = min_temp,
      max_temp = max_temp
    )

    new_row <- tibble(
      hypothesis_id = h_id,
      group_id      = g_id,
      pair_id       = p_id,
      iteration     = iteration_this,
      rating_val    = rating_call$rating,
      rating_expl   = rating_call$explanation,
      temperature   = rating_call$temperature,
      error         = rating_call$error %||% NA_character_
    )

    # Persist after every single call so that no paid result is lost
    if (can_append) {
      # The header is written only when the file does not exist yet
      write_csv(new_row, output_csv, append = file.exists(output_csv))
    } else {
      new_results[[length(new_results) + 1]] <- new_row
      write_csv(bind_rows(rating_results_existing, bind_rows(new_results)), output_csv)
    }
  }

  # Return what is on disk; if nothing was ever written there is no file to
  # read, so hand back the (empty) in-memory table instead.
  if (file.exists(output_csv)) {
    read_csv(output_csv, show_col_types = FALSE)
  } else {
    rating_results_existing
  }
}

# --- 4) Aggregate ratings to group level (for correlation with group-level outcome) ---
# Averages the discussion-level ratings per (group_id, hypothesis_id),
# separately for discussions that did / did not include a trans worker, and
# adds the row mean of those rating columns as `rating_overall`.
#
# Arguments:
#   rating_results_long: output of rate_discussions_for_hypotheses().
#   discuss_obs:         table with group_id, pair_id and
#                        group_obs_trans_included (logical, or 0/1).
#   output_csv:          cache file; if it already exists it is read and
#                        returned as-is.
# Returns: a tibble with group_id, hypothesis_id, rating_trans,
#   rating_non_trans and rating_overall (plus rating_NA if some discussions
#   have no match in `discuss_obs`).
# Side effect: writes the result to `output_csv` (creating its directory).
aggregate_discussion_ratings <- function(rating_results_long,
                                         discuss_obs,  # must contain group_obs_trans_included, etc.
                                         output_csv = "data/ml_hypothesis_generation/group_level_ratings.csv") {
  # If file exists, skip
  if (file.exists(output_csv)) {
    cat("[aggregate_discussion_ratings] Loading existing group-level ratings from disk...\n")
    aggregated_data <- read_csv(output_csv, show_col_types = FALSE)
    return(aggregated_data)
  }

  cat("[aggregate_discussion_ratings] Aggregating discussion ratings to group level...\n")

  # Step A: one lookup row per discussion. The flag is coerced to logical so
  # that 0/1 coding yields the same rating_TRUE / rating_FALSE column names,
  # and duplicates are removed so the join cannot multiply rating rows.
  trans_lookup <- discuss_obs %>%
    select(group_id, pair_id, group_obs_trans_included) %>%
    mutate(group_obs_trans_included = as.logical(group_obs_trans_included)) %>%
    distinct()

  rating_plus_obs <- rating_results_long %>%
    left_join(trans_lookup, by = c("group_id", "pair_id"))

  # Step B: mean rating per group, hypothesis and trans / non-trans flag
  group_hyp <- rating_plus_obs %>%
    group_by(group_id, hypothesis_id, group_obs_trans_included) %>%
    summarise(rating_mean = mean(rating_val, na.rm = TRUE), .groups = "drop")

  # Step C: one column per flag value. any_of() tolerates a flag value that
  # never occurs in the data.
  group_hyp_wide <- group_hyp %>%
    pivot_wider(
      id_cols      = c(group_id, hypothesis_id),
      names_from   = group_obs_trans_included,
      values_from  = rating_mean,
      names_prefix = "rating_"
    ) %>%
    rename(any_of(c(
      rating_trans     = "rating_TRUE",
      rating_non_trans = "rating_FALSE"
    )))

  # Downstream code reads both columns, so a column that is absent from the
  # data is added as all-NA.
  for (col in setdiff(c("rating_trans", "rating_non_trans"), names(group_hyp_wide))) {
    group_hyp_wide[[col]] <- rep(NA_real_, nrow(group_hyp_wide))
  }

  # Overall rating = row mean over every rating_* column
  rating_matrix <- as.matrix(select(group_hyp_wide, starts_with("rating_")))
  group_hyp_wide$rating_overall <- rowMeans(rating_matrix, na.rm = TRUE)

  # The result stays "long" in hypothesis_id: one row per
  # (group_id, hypothesis_id).
  dir.create(dirname(output_csv), showWarnings = FALSE, recursive = TRUE)
  write_csv(group_hyp_wide, output_csv)
  group_hyp_wide
}

# --- 5) Merge with group-level outcome, compute correlations ---
# For every hypothesis, correlates rating_trans, rating_non_trans and
# rating_overall with r2_choose_trans across groups (stats::cor.test).
#
# Arguments:
#   group_hyp_wide:   output of aggregate_discussion_ratings().
#   hypotheses_tbl:   tibble with hypothesis_id and hypothesis_cleaned.
#   r2_choices_group: group-level outcome with group_id and r2_choose_trans.
#   output_csv:       cache file; if it already exists it is read and
#                     returned as-is.
# Returns: a tibble with one row per (hypothesis_id, rating_type) and the
#   columns hypothesis, correlation, p_value, conf_low, conf_high. Values are
#   NA when fewer than 3 complete pairs exist or cor.test() fails.
# Side effect: writes the result to `output_csv` (creating its directory).
merge_with_outcome_and_correlate <- function(group_hyp_wide,
                                             hypotheses_tbl,
                                             r2_choices_group,
                                             output_csv = "data/ml_hypothesis_generation/correlations.csv") {
  # If file exists, skip
  if (file.exists(output_csv)) {
    cat("[merge_with_outcome_and_correlate] Found existing correlation results file. Loading...\n")
    corr_out <- read_csv(output_csv, show_col_types = FALSE)
    return(corr_out)
  }

  cat("[merge_with_outcome_and_correlate] Merging with group-level outcome and computing correlations...\n")

  # One result row; all statistics default to NA
  result_row <- function(h, h_text, rc,
                         correlation = NA_real_, p_value = NA_real_,
                         conf_low = NA_real_, conf_high = NA_real_) {
    tibble(
      hypothesis_id = h,
      hypothesis    = h_text,
      rating_type   = rc,
      correlation   = correlation,
      p_value       = p_value,
      conf_low      = conf_low,
      conf_high     = conf_high
    )
  }

  # Correlate one rating column with the outcome for one hypothesis
  correlate_one <- function(subdata, h, h_text, rc) {
    x <- subdata[[rc]]
    y <- subdata[["r2_choose_trans"]]
    if (is.null(x) || is.null(y)) {
      return(result_row(h, h_text, rc))
    }

    # Keep complete pairs only
    pairs <- data.frame(x = x, y = y)
    pairs <- pairs[complete.cases(pairs), ]
    if (nrow(pairs) < 3) {
      return(result_row(h, h_text, rc))
    }

    ct <- tryCatch(stats::cor.test(pairs$x, pairs$y), error = function(e) NULL)
    if (is.null(ct)) {
      return(result_row(h, h_text, rc))
    }

    # cor.test() provides no confidence interval for very small samples;
    # report NA limits instead of dropping the columns.
    ci <- ct$conf.int
    if (is.null(ci)) {
      ci <- c(NA_real_, NA_real_)
    }

    result_row(
      h, h_text, rc,
      correlation = ct$estimate[[1]],
      p_value     = ct$p.value,
      conf_low    = ci[1],
      conf_high   = ci[2]
    )
  }

  merged <- group_hyp_wide %>%
    left_join(r2_choices_group, by = "group_id")

  rating_cols <- c("rating_trans", "rating_non_trans", "rating_overall")
  out_list    <- list()

  for (h in unique(merged$hypothesis_id)) {
    subdata <- merged %>% filter(hypothesis_id == h)

    # Text of this hypothesis (first match if the id occurs more than once)
    h_text <- hypotheses_tbl %>%
      filter(hypothesis_id == h) %>%
      pull(hypothesis_cleaned) %>%
      first()

    for (rc in rating_cols) {
      out_list[[length(out_list) + 1]] <- correlate_one(subdata, h, h_text, rc)
    }
  }

  corr_out <- bind_rows(out_list)

  # Save to disk
  dir.create(dirname(output_csv), showWarnings = FALSE, recursive = TRUE)
  write_csv(corr_out, output_csv)
  corr_out
}


# EXAMPLE USAGE / MAIN SCRIPT ------------------------------------------

# (Below is a minimal example. Adjust file paths / variable names to your real dataset.)

# 1) Load the utterance-level data containing all (group_id, pair_id)
#    discussions; no filter on whether a trans worker is included.
#    NOTE(review): `quote` is presumably a column of the CSV holding the
#    English utterance text -- confirm.
transcripts_with_embeddings <- read_csv("data/cleaned/transcripts_with_embeddings.csv") %>%
  mutate(speech_english = quote)

# Group-level outcome: one r2_choose_trans value per group.
# NOTE(review): `r2_choices_num` and `sum_na()` are not defined in this file;
# they must already exist in the session.
r2_choices_group <- r2_choices_num %>%
  group_by(group_id) %>%
  summarise(
    r2_choose_trans = sum_na(r2_choose_trans)
  )

# 2) A data frame `discuss_obs` with the columns group_id, round and
#    group_obs_trans_included is used in step 6; it is not created in this
#    file and must already exist in the session.

# 3) Prepare discussion data and create JSON for each (group_id, pair_id).
discussion_level_transcript_tbl <- prepare_discussion_data(transcripts_with_embeddings)

# 4) Generate hypotheses from randomly paired discussions:
#    n_hypotheses pairs => n_hypotheses hypotheses.
set.seed(12345)
n_hypotheses <- 500

hypotheses_tbl <- generate_discussion_hypotheses(
  discussion_tbl = discussion_level_transcript_tbl,
  n_pairs  = n_hypotheses,
  model    = "gpt-4o-mini",
  verbose  = TRUE
)

# 5) Rate each discussion w.r.t. each hypothesis.


# Set test_only to TRUE to rate a random sample of discussions_to_test
# discussions instead of all of them.
set.seed(12345)
test_only <- FALSE
if (test_only) {
  discussions_to_test <- 30
  discussion_level_transcript_tbl_test <- discussion_level_transcript_tbl %>% slice_sample(n = discussions_to_test)
} else {
  discussion_level_transcript_tbl_test <- discussion_level_transcript_tbl
}


#    n_ratings_per_row = number of ratings per discussion-hypothesis combination.
set.seed(12345)
n_ratings_per_row <- 1
rating_results_long <- rate_discussions_for_hypotheses(
  discussion_tbl = discussion_level_transcript_tbl_test,
  hypotheses_tbl = hypotheses_tbl,
  repeats_per_discussion = n_ratings_per_row,
  model    = "gpt-4o-mini",
  verbose  = TRUE
)

# discussion_level_transcript_tbl

# 6) Aggregate discussion-level ratings to group-level, separating trans vs. non-trans.
#    `round` is converted to an integer and used as pair_id for the join.
group_hyp_wide <- aggregate_discussion_ratings(
  rating_results_long,
  discuss_obs = discuss_obs %>% select(group_id, round, group_obs_trans_included) %>% mutate(pair_id = as.integer(round))  # must have group_id, pair_id, group_obs_trans_included
)

# 7) Merge with group-level outcome (r2_choose_trans, etc.) and do correlations
correlations <- merge_with_outcome_and_correlate(
  group_hyp_wide,
  hypotheses_tbl,
  r2_choices_group
)

plot_correlation_results(correlations)

# 8) Done! Inspect your final results / correlations.


# CHECK TOTAL COST:
# TOTAL = ~1000 discussions x 500 hypotheses x 1 rating = ~0.5 million GPT calls